import os
import fal_client
import folder_paths
import configparser
import base64
import io
from PIL import Image
import logging
import json
import requests
import numpy as np
import torch

# NOTE: basicConfig configures the *root* logger (DEBUG level, stderr handler)
# at import time; it is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG)
# Module-scoped logger used by the node class below.
logger = logging.getLogger(name=__name__)

class BaseFalAPIFluxNode:
    """Base node for image generation through a fal.ai Flux endpoint.

    The endpoint must be set with ``set_api_endpoint`` before ``generate``
    runs (``call_api`` raises ``ValueError`` otherwise). The API key is read
    from ``config.ini`` in the parent directory of this module's directory,
    falling back to the ``FAL_KEY`` environment variable.
    """

    # Seconds to wait for a single generated image to download before the
    # image is skipped; without a timeout a stalled server hangs the node.
    DOWNLOAD_TIMEOUT = 120

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "image generation"

    def __init__(self):
        self.api_key = self.get_api_key()
        # fal_client is never handed the key directly in this class; it is
        # exported via FAL_KEY instead. Only export a real key: assigning None
        # to os.environ raises TypeError, and an empty string would clobber
        # a key the user already exported.
        if self.api_key:
            os.environ['FAL_KEY'] = self.api_key
        self.api_endpoint = None

    def get_api_key(self):
        """Return the fal.ai API key, or None if none is configured.

        Looks up ``[falai] api_key`` in ``config.ini`` (two directory levels
        above this file). If the file, section or value is missing/empty, the
        ``FAL_KEY`` environment variable is used instead.
        """
        config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'config.ini')
        if os.path.exists(config_path):
            config = configparser.ConfigParser()
            config.read(config_path)
            api_key = config.get('falai', 'api_key', fallback=None)
            if api_key:
                return api_key
        return os.environ.get('FAL_KEY') or None

    def set_api_endpoint(self, endpoint):
        """Set the fal.ai application id that ``call_api`` submits to."""
        self.api_endpoint = endpoint

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
                "width": ("INT", {"default": 1024, "step": 8}),
                "height": ("INT", {"default": 1024, "step": 8}),
                "num_inference_steps": ("INT", {"default": 28, "min": 1, "max": 100}),
                "guidance_scale": ("FLOAT", {"default": 3.5, "min": 0.1, "max": 40.0}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 4}),
                "enable_safety_checker": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            }
        }

    def prepare_arguments(self, prompt, width, height, num_inference_steps, guidance_scale, num_images, enable_safety_checker, seed=None, **kwargs):
        """Build the request payload for the fal.ai endpoint.

        ``seed`` of None or 0 is omitted from the payload. Extra keyword
        arguments are accepted and ignored so subclasses can pass through
        node inputs they handle themselves.

        Raises:
            ValueError: if no API key is configured, or width/height is None.
        """
        if not self.api_key:
            raise ValueError("API key is not set. Please check your config.ini file or the FAL_KEY environment variable.")

        arguments = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_images": num_images,
            "enable_safety_checker": enable_safety_checker
        }

        # Handle custom image size
        if width is None or height is None:
            raise ValueError("Width and height must be provided when using custom image size")
        arguments["image_size"] = {
            "width": width,
            "height": height
        }

        if seed is not None and seed != 0:
            arguments["seed"] = seed

        return arguments

    def call_api(self, arguments):
        """Submit ``arguments`` to the configured endpoint and return the result.

        Raises:
            ValueError: if no endpoint has been set.
            RuntimeError: wrapping any error raised by fal_client.
        """
        if not self.api_endpoint:
            raise ValueError("API endpoint is not set. Please set it using set_api_endpoint() method.")

        # default=str: a debug log must never turn into a request failure.
        logger.debug(f"Full API request payload: {json.dumps(arguments, indent=2, default=str)}")

        try:
            handler = fal_client.submit(
                self.api_endpoint,
                arguments=arguments,
            )
            result = handler.get()
        except Exception as e:
            logger.error(f"API error details: {str(e)}")
            # Some exceptions expose a `response` attribute that may be None;
            # dereferencing it blindly would mask the original error.
            response = getattr(e, 'response', None)
            if response is not None:
                logger.error(f"API error response: {getattr(response, 'text', response)}")
            raise RuntimeError(f"An error occurred when calling the fal.ai API: {str(e)}") from e

        logger.debug(f"API response: {json.dumps(result, indent=2, default=str)}")
        return result

    def process_images(self, result):
        """Convert the images in an API result into one batched tensor.

        Each entry of ``result["images"]`` must be a dict with a ``url`` (a
        ``data:image`` URI or a downloadable URL). Entries that are malformed,
        cannot be fetched, or do not decode as images are logged and skipped.

        Returns:
            A one-element list holding a (N, H, W, 3) float32 tensor in [0, 1].

        Raises:
            RuntimeError: if the result has no images or none could be processed.
        """
        if "images" not in result or not result["images"]:
            logger.error("No images were generated by the API.")
            raise RuntimeError("No images were generated by the API.")

        output_images = []
        for index, img_info in enumerate(result["images"]):
            try:
                img_tensor = self._image_info_to_tensor(index, img_info)
            except Exception as e:
                logger.error(f"Failed to process image {index}: {str(e)}")
                continue
            if img_tensor is not None:
                output_images.append(img_tensor)

        if not output_images:
            logger.error("Failed to process any of the generated images.")
            raise RuntimeError("Failed to process any of the generated images.")

        # Stack all images into a single batch tensor (requires equal H and W).
        output_tensor = torch.cat(output_images, dim=0)
        logger.debug(f"Returning batched tensor with shape: {output_tensor.shape}")
        return [output_tensor]

    def _image_info_to_tensor(self, index, img_info):
        """Turn one ``result["images"]`` entry into a (1, H, W, 3) float32 tensor.

        Returns None (after logging) if the entry is invalid, cannot be
        fetched, or its bytes do not decode as an image.
        """
        logger.debug(f"Processing image {index}: {json.dumps(img_info, indent=2, default=str)}")
        if not isinstance(img_info, dict) or not img_info.get("url"):
            logger.error(f"Invalid image info for image {index}")
            return None

        img_url = img_info["url"]
        logger.debug(f"Image URL: {img_url[:100]}...")  # Log the first 100 characters of the URL

        img_data = self._fetch_image_bytes(index, img_url)
        if img_data is None:
            return None
        logger.debug(f"First 20 bytes of image data: {img_data[:20]}")

        try:
            img = Image.open(io.BytesIO(img_data))
            img.load()  # Image.open is lazy; force decoding so errors surface here.
        except Exception as e:
            # Undecodable bytes (e.g. an error page) are skipped rather than
            # reinterpreted as raw pixels, which only produced noise images.
            logger.error(f"Failed to open image data: {str(e)}")
            return None
        logger.debug(f"Opened image with size: {img.size} and mode: {img.mode}")

        if img.mode != 'RGB':
            img = img.convert('RGB')

        img_np = np.array(img).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
        return torch.from_numpy(img_np).unsqueeze(0)  # (1, H, W, 3)

    def _fetch_image_bytes(self, index, img_url):
        """Return the bytes behind a ``data:image`` URI or remote URL, or None on failure."""
        if img_url.startswith("data:image"):
            try:
                _, encoded = img_url.split(",", 1)
                return base64.b64decode(encoded)
            except ValueError:
                # Also covers binascii.Error (a ValueError subclass) from bad base64.
                logger.error(f"Failed to split image URL for image {index}")
                return None

        try:
            response = requests.get(img_url, timeout=self.DOWNLOAD_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            logger.error(f"Failed to download image from URL for image {index}: {str(e)}")
            return None
        return response.content

    @staticmethod
    def _image_to_uint8_array(image):
        """Convert a tensor/array image into an (H, W, C) uint8 array, C in {3, 4}.

        Accepted layouts: (1, H, W, C) batch of one, (H, W, 3|4|1), (3, H, W),
        (1, H, W) and (H, W); single-channel input is replicated to 3 channels.
        Floating-point (and bool) data is treated as [0, 1] and scaled by 255;
        other non-uint8 data is clipped to [0, 255]. Values are clipped rather
        than min-max stretched, so image contrast is preserved and uniform
        images do not divide by zero.

        Raises:
            ValueError: for any other shape.
        """
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        image = np.asarray(image)

        if image.ndim == 4 and image.shape[0] == 1:
            image = image[0]

        if image.ndim == 2:  # (H, W) grayscale
            image = np.repeat(image[..., np.newaxis], 3, axis=2)
        elif image.ndim == 3:
            if image.shape[2] == 3:  # (H, W, 3) RGB
                pass
            elif image.shape[2] == 1:  # (H, W, 1) grayscale
                image = np.repeat(image, 3, axis=2)
            elif image.shape[0] == 3:  # (3, H, W) RGB
                image = np.transpose(image, (1, 2, 0))
            elif image.shape[0] == 1:  # (1, H, W) grayscale
                image = np.repeat(image[0][..., np.newaxis], 3, axis=2)
            elif image.shape[2] == 4:  # (H, W, 4) RGBA; converted to RGB later
                pass
            else:
                raise ValueError(f"Unsupported image shape: {image.shape}")
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

        if image.dtype != np.uint8:
            if image.dtype == np.bool_ or np.issubdtype(image.dtype, np.floating):
                scaled = image.astype(np.float32) * 255.0
            else:
                scaled = image.astype(np.float32)
            image = np.clip(scaled, 0, 255).astype(np.uint8)

        return np.ascontiguousarray(image)

    def upload_image(self, image, max_size=1024):
        """Upload an image to fal.ai storage as PNG and return its URL.

        Args:
            image: a torch tensor, numpy array (see ``_image_to_uint8_array``
                for accepted layouts) or a PIL image.
            max_size: images whose longer side exceeds this are downscaled
                (aspect ratio preserved) before upload.

        Raises:
            Any conversion or upload error, after logging it.
        """
        try:
            if isinstance(image, (torch.Tensor, np.ndarray)):
                image = Image.fromarray(self._image_to_uint8_array(image))

            # Ensure image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.LANCZOS)

            buffered = io.BytesIO()
            image.save(buffered, format="PNG")

            url = fal_client.upload(buffered.getvalue(), "image/png")
            logger.info(f"Image uploaded successfully. URL: {url}")
            return url
        except Exception as e:
            logger.error(f"Failed to process or upload image: {str(e)}")
            raise

    def generate(self, **kwargs):
        """Node entry point: build the payload, call the API, return image tensors."""
        arguments = self.prepare_arguments(**kwargs)
        result = self.call_api(arguments)
        output_images = self.process_images(result)
        return tuple(output_images)
